import torch
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score, precision_score, recall_score, f1_score, matthews_corrcoef

# Names exported by ``from <this module> import *``; only the experiment class is public.
__all__ = ['RLCExperiment']


class RLCExperiment(object):
    """Drive the train / validate / test loop of a model and report metrics.

    The model must be callable on a batch and return a
    ``(loss, predictions, targets)`` triple. Besides that, this class only
    uses the model's ``train()``, ``eval()``, ``reset_state()``, ``detach()``,
    ``state_dict()`` and ``load_state_dict()`` members.

    Args:
        model: The model to optimise (see the contract above).
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        reporter: Callable invoked with a metrics dict after each test
            evaluation and once more with the final ``'BEST'`` summary.
        experiment_dir: Directory into which the best test predictions and
            targets are written (``predictions.pt`` / ``targets.pt``).
        device: Stored on the instance; not used by any method of this class.
        num_epochs: Maximum number of training epochs.
        early_stopping: Truthy to enable early stopping (see ``train``).
        test_frequency: Evaluate on the test set every this many epochs.
            ``0`` disables the periodic evaluation.
        **kwargs: Accepted and ignored.
    """

    # Number of initial epochs during which early stopping is never triggered.
    _EARLY_STOPPING_WARMUP = 50

    def __init__(self, model, optimizer, reporter, experiment_dir, device, num_epochs=300, early_stopping=0, test_frequency=10, **kwargs):
        self.model = model
        self.optimizer = optimizer
        self.reporter = reporter
        self.experiment_dir = experiment_dir
        self.device = device

        self.num_epochs = num_epochs
        self.early_stopping = early_stopping
        self.test_frequency = test_frequency

    @staticmethod
    def _detach(value):
        """Return ``value`` cut from the autograd graph when it is a tensor."""
        return value.detach() if torch.is_tensor(value) else value

    @staticmethod
    def _concat_batches(chunks):
        """Join per-batch outputs along dim 0; ``None`` when there were no batches."""
        if not chunks:
            return None
        if len(chunks) == 1:
            # A single batch is returned as-is, without an extra copy.
            return chunks[0]
        return torch.cat(chunks, dim=0)

    def train_epoch(self, data_loader):
        """Run one optimisation pass over ``data_loader``.

        Returns:
            ``(total_loss, predictions, targets)`` where ``total_loss`` is the
            sum of the per-batch ``loss.item()`` values and the other two are
            the per-batch outputs concatenated along dim 0 (``None`` if the
            loader yielded nothing). The returned tensors are detached from
            the autograd graph.
        """
        self.model.train()
        self.model.reset_state()
        total_loss = 0
        predictions, targets = [], []
        for batch in data_loader:
            self.optimizer.zero_grad()
            loss, p, t = self.model(batch)
            loss.backward()
            self.optimizer.step()
            self.model.detach()
            total_loss += loss.item()
            # Detach before storing so the batch's graph is not kept alive
            # for the rest of the epoch and the result is usable as plain data.
            predictions.append(self._detach(p))
            targets.append(self._detach(t))
        return total_loss, self._concat_batches(predictions), self._concat_batches(targets)

    def evaluate_epoch(self, data_loader):
        """Evaluate the model on ``data_loader`` without tracking gradients.

        The model is switched to eval mode and left there.

        Returns:
            ``(total_loss, predictions, targets)`` with the same layout as
            ``train_epoch``.
        """
        total_loss = 0
        predictions, targets = [], []
        with torch.no_grad():
            self.model.eval()
            for batch in data_loader:
                loss, p, t = self.model(batch)
                total_loss += loss.item()
                predictions.append(p)
                targets.append(t)
        return total_loss, self._concat_batches(predictions), self._concat_batches(targets)

    def train(self, train_loader, val_loader, test_loader):
        """Train for up to ``num_epochs`` epochs with periodic test evaluation.

        Every ``test_frequency`` epochs (and on the epoch that triggers early
        stopping) the test set is evaluated. The decision threshold is tuned
        on the current validation predictions and then applied to the test
        predictions. Whenever the test MCC matches or beats the best seen so
        far, the test predictions and targets are saved under
        ``experiment_dir``.

        Early stopping, when enabled, fires once the epoch index has reached
        the warm-up length and the current validation loss exceeds the mean
        of all validation losses recorded so far (current one included).

        Returns:
            The ``'BEST'`` summary dict: minimum train / validation / test
            losses plus the per-metric best test values. Losses are ``None``
            and metrics are absent when the corresponding evaluation never
            ran (e.g. ``num_epochs`` smaller than ``test_frequency``).
        """
        stop = False
        train_losses, val_losses, test_losses = [], [], []
        best_metrics = None

        for i in range(self.num_epochs):
            # Training predictions/targets are not needed for model selection.
            train_loss, _, _ = self.train_epoch(train_loader)
            train_losses.append(train_loss)

            val_loss, val_predictions, val_targets = self.evaluate_epoch(val_loader)
            val_losses.append(val_loss)

            if (self.early_stopping and i >= self._EARLY_STOPPING_WARMUP
                    and val_loss > np.mean(val_losses)):
                stop = True

            # A test_frequency of 0 means "no periodic testing" rather than
            # a ZeroDivisionError.
            periodic_test = bool(self.test_frequency) and (i + 1) % self.test_frequency == 0
            if periodic_test or stop:
                test_loss, test_predictions, test_targets = self.evaluate_epoch(test_loader)
                test_losses.append(test_loss)

                _, tuned_threshold = self.calculate_metrics(val_predictions, val_targets)
                new_metrics = self.calculate_metrics(test_predictions, test_targets, threshold=tuned_threshold)
                best_metrics = self.update_best_metrics(best_metrics, new_metrics)

                self.summarize_epoch(i, train_loss, val_loss, test_loss, new_metrics, verbose=True)

                # best_metrics already includes this epoch, so the comparison
                # holds exactly when this epoch ties or sets the best MCC.
                if new_metrics["MCC"] >= best_metrics["MCC"]:
                    # Formatting (not ``+``) also accepts pathlib.Path directories.
                    torch.save(test_predictions, f"{self.experiment_dir}/predictions.pt")
                    torch.save(test_targets, f"{self.experiment_dir}/targets.pt")

            if stop:
                break

        best_results = self.summarize_epoch(
            'BEST',
            min(train_losses, default=None),
            min(val_losses, default=None),
            min(test_losses, default=None),
            best_metrics if best_metrics is not None else {},
            verbose=True,
        )

        return best_results

    @staticmethod
    def calculate_metrics(predictions, targets, threshold=None):
        """Compute average precision, ROC AUC and MCC for the predictions.

        Args:
            predictions: Tensor of scores. When ``threshold`` is ``None`` it
                must broadcast against the 9 candidate thresholds to a 2-D
                result with one column per threshold (e.g. shape ``(N, 1)``).
            targets: Ground-truth labels accepted by the sklearn metrics.
            threshold: Decision threshold for the MCC. When ``None``, the
                thresholds 0.1, 0.2, ..., 0.9 are tried and the one with the
                highest MCC is used.

        Returns:
            ``metrics`` when ``threshold`` is given, otherwise
            ``(metrics, best_threshold)``; ``metrics`` is a dict with float
            values under the keys ``'APS'``, ``'AUC'`` and ``'MCC'``.
        """
        # sklearn converts its inputs to NumPy, which requires CPU tensors
        # that do not track gradients.
        if torch.is_tensor(predictions):
            predictions = predictions.detach().cpu()
        if torch.is_tensor(targets):
            targets = targets.detach().cpu()

        ap = average_precision_score(targets, predictions)
        auc = roc_auc_score(targets, predictions)
        if threshold is None:
            tau = torch.linspace(0.1, 0.9, 9)
            # One column of hard 0/1 decisions per candidate threshold.
            y_hard = torch.ge(predictions, tau).long()
            scores = [matthews_corrcoef(targets, y_hard[:, i]) for i in range(tau.size(0))]
            phi, max_ind = torch.tensor(scores).max(dim=0)
            threshold = tau[max_ind]
            return {'APS': float(ap), 'AUC': float(auc), 'MCC': float(phi)}, float(threshold)

        y_hard = torch.ge(predictions, threshold).long()
        phi = matthews_corrcoef(targets, y_hard)
        return {'APS': float(ap), 'AUC': float(auc), 'MCC': float(phi)}

    @staticmethod
    def update_best_metrics(best, new):
        """Merge ``new`` into ``best``, keeping the larger value per key.

        Args:
            best: Running best metrics, or ``None`` on the first call.
            new: Metrics of the latest evaluation; must contain every key
                present in ``best``.

        Returns:
            A copy of ``new`` when ``best`` is ``None`` (so later updates do
            not modify the caller's dict); otherwise ``best`` itself, updated
            in place.
        """
        if best is None:
            return dict(new)
        for key, value in best.items():
            if new[key] >= value:
                best[key] = new[key]
        return best

    def summarize_epoch(self, epoch, train_loss, val_loss, test_loss, test_metrics, verbose=False):
        """Build the metrics record for one epoch and optionally report it.

        Args:
            epoch: Epoch identifier stored under ``'epoch'``.
            train_loss, val_loss, test_loss: Loss values to record.
            test_metrics: Mapping of extra metrics merged into the record.
            verbose: When true, the record is passed to ``self.reporter``.

        Returns:
            The assembled metrics dict.
        """
        metrics = {
            'epoch': epoch,
            'Training Loss': train_loss,
            'Validation Loss': val_loss,
            'Test Loss': test_loss,
        }
        metrics.update(test_metrics)
        if verbose:
            self.reporter(metrics)
        return metrics

    def save_model(self, model_dir):
        """Save the model's ``state_dict`` to ``model_dir`` via ``torch.save``."""
        torch.save(self.model.state_dict(), model_dir)

    def load_model(self, model_dir):
        """Load a ``state_dict`` from ``model_dir`` into the model.

        NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
        from trusted sources.
        """
        self.model.load_state_dict(torch.load(model_dir))